import torch
from torch.utils.data import Subset
from typing import List
from sklearn.neighbors import kneighbors_graph
import numpy as np
from tqdm import tqdm
import matplotlib.pyplot as plt
import os
import math
import random

def set_seed(seed: int = 42):
    """Seed Python's, NumPy's and PyTorch's global RNGs with the same value.

    When CUDA is available, every CUDA device is seeded as well.
    """
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

def get_device():
    """Return the preferred torch device: CUDA first, then Apple MPS, otherwise CPU."""
    # Probes are lambdas so torch.backends.mps is only touched when CUDA is absent.
    candidates = (
        ("cuda", lambda: torch.cuda.is_available()),
        ("mps", lambda: torch.backends.mps.is_available()),
    )
    for device_name, probe in candidates:
        if probe():
            return torch.device(device_name)
    return torch.device("cpu")

def fmt_nf(x: float) -> str:
    """Render *x* with at most three decimals and the decimal point spelled ``p``.

    Trailing zeros, and a decimal point left dangling by removing them, are dropped:
    0.5 -> "0p5", 2.0 -> "2", 0.125 -> "0p125". (Presumably meant for building
    file/run names — confirm with callers.)
    """
    text = format(x, ".3f")
    trimmed = text.rstrip("0").rstrip(".")
    return trimmed.replace(".", "p")
    

def gaussian_kernel(A: torch.Tensor, B: torch.Tensor, sigma: float = 1/(28*28)) -> torch.Tensor:
    """Return the Gaussian (RBF) Gram matrix ``K[i, j] = exp(-sigma * ||A[i] - B[j]||^2)``.

    Args:
        A: (n, d) tensor whose rows are samples.
        B: (m, d) tensor whose rows are samples.
        sigma: factor multiplying the squared distance in the exponent. It acts as an
            inverse bandwidth (larger -> narrower kernel), not as a standard deviation.
            The default 1/784 equals 1/(28*28), presumably chosen for flattened 28x28
            images — confirm with callers.

    Returns:
        (n, m) tensor of kernel values in (0, 1].
    """
    # Use ||a - b||^2 = ||a||^2 + ||b||^2 - 2 a.b instead of broadcasting A against B.
    # Broadcasting materialises an (n, m, d) tensor, and autograd keeps it for backward.
    # This form stays O(n*m) in memory and avoids a sqrt-then-square round trip.
    sq_norms_a = A.pow(2).sum(dim=1, keepdim=True)       # (n, 1)
    sq_norms_b = B.pow(2).sum(dim=1).unsqueeze(0)        # (1, m)
    sq_dists = sq_norms_a + sq_norms_b - 2.0 * (A @ B.T)
    # Floating-point cancellation can give tiny negative values for (near-)identical rows.
    sq_dists = sq_dists.clamp_min(0.0)
    return torch.exp(-sigma * sq_dists)

def ECMMD(Z: torch.Tensor, Y: torch.Tensor, X: torch.Tensor, kernel, neighbors: int) -> torch.Tensor:
    """Nearest-neighbour ECMMD statistic between samples Z and Y, conditioned on X.

    Builds the k-nearest-neighbour graph of the rows of X (self excluded) and
    averages, over every directed edge i -> j of that graph,

        H[i, j] = k(Z_i, Z_j) + k(Y_i, Y_j) - k(Z_i, Y_j) - k(Y_i, Z_j)

    normalising by ``n * k``.

    Args:
        Z: (n, d) tensor of samples (e.g. model outputs), one row per item.
        Y: (n, d) tensor of reference samples, row-aligned with Z.
        X: (n, p) tensor of conditioning values. It is only used to find neighbours
            and is detached, so no gradient flows through it.
        kernel: callable ``(A, B) -> (len(A), len(B))`` Gram matrix, e.g. ``gaussian_kernel``.
        neighbors: requested neighbours per row. It is capped at ``n - 1``, because a
            batch of n rows has only n - 1 candidates (e.g. a short final DataLoader
            batch).

    Returns:
        0-dim tensor on the device/dtype of the kernel output. It is differentiable
        through ``kernel``, and it is zero when n < 2 (no neighbours exist).
    """
    batch = X.shape[0]
    # kneighbors_graph(include_self=False) requires n_neighbors < n_samples.
    k = min(neighbors, batch - 1)

    # kernel matrices
    kernel_ZZ = kernel(Z, Z)
    kernel_YY = kernel(Y, Y)
    kernel_ZY = kernel(Z, Y)
    kernel_YZ = kernel(Y, Z)

    # H matrix
    H = kernel_ZZ + kernel_YY - kernel_ZY - kernel_YZ

    if k < 1:
        # No neighbour pairs: return a zero that is still attached to the autograd
        # graph, so callers can call .backward() unconditionally.
        return (H * 0).sum()

    # Neighbour adjacency matrix: N_X[i, j] = 1 iff j is among the k nearest rows to i.
    N_X = kneighbors_graph(X.detach().cpu().numpy(), k, include_self=False).toarray()
    # Move the mask to H's device/dtype instead of copying H to the CPU.
    N_X = torch.as_tensor(N_X, dtype=H.dtype, device=H.device)

    # ECMMD
    return torch.sum(H * N_X) / (batch * k)

def split_dataset_by_class(dataset, train_samples_per_class, val_samples_per_class):
    """Split *dataset* into class-balanced, disjoint train and validation subsets.

    Each item of *dataset* is unpacked as ``(input, label)``. For every class, its
    indices are shuffled with NumPy's global RNG. The first ``train_samples_per_class``
    go to train and the next ``val_samples_per_class`` go to validation; if a class
    has fewer items, validation gets whatever is left. Both index lists are then
    shuffled once more so classes are interleaved.

    Labels may be Python/NumPy ints or 0-dim tensors (converted via ``.item()``).
    Any number of classes is supported. Classes are processed in sorted label
    order, so for labels 0..9 the RNG draws match the previous fixed-range version.

    Returns:
        ``(train_subset, val_subset)`` as ``torch.utils.data.Subset`` objects with
        int64 index tensors.
    """
    class_indices = {}
    for idx, (_, label) in enumerate(dataset):
        # Tensors hash by identity, so convert them to plain Python scalars first.
        key = label.item() if isinstance(label, torch.Tensor) else label
        class_indices.setdefault(key, []).append(idx)

    train_indices = []
    val_indices = []
    for cls in sorted(class_indices):
        indices = np.array(class_indices[cls], dtype=np.int64)
        np.random.shuffle(indices)

        train_indices.extend(indices[:train_samples_per_class])
        val_indices.extend(indices[train_samples_per_class:train_samples_per_class + val_samples_per_class])

    # An explicit dtype keeps empty index lists integral (np.array([]) would be float64).
    train_indices = torch.tensor(np.random.permutation(np.asarray(train_indices, dtype=np.int64)))
    val_indices = torch.tensor(np.random.permutation(np.asarray(val_indices, dtype=np.int64)))

    return Subset(dataset, train_indices), Subset(dataset, val_indices)

def select_samples_by_class(dataset, num_samples_per_class=500):
    """Draw a class-balanced random subset of *dataset* without replacement.

    Each item of *dataset* is unpacked as ``(input, label)``. For every class present,
    ``num_samples_per_class`` indices are drawn with NumPy's global RNG. The combined
    index list is then shuffled so classes are interleaved.

    Labels may be Python/NumPy ints or 0-dim tensors (converted via ``.item()``).
    Any number of classes is supported. Classes are processed in sorted label
    order, so for labels 0..9 the RNG draws match the previous fixed-range version.

    Raises:
        ValueError: propagated from ``np.random.choice`` when a class has fewer than
            ``num_samples_per_class`` items.

    Returns:
        ``torch.utils.data.Subset`` with an int64 index tensor.
    """
    class_indices = {}
    for idx, (_, label) in enumerate(dataset):
        # Tensors hash by identity, so convert them to plain Python scalars first.
        key = label.item() if isinstance(label, torch.Tensor) else label
        class_indices.setdefault(key, []).append(idx)

    selected_indices = []
    for cls in sorted(class_indices):
        selected_indices.extend(np.random.choice(class_indices[cls], num_samples_per_class, replace=False))
    # An explicit dtype keeps an empty selection integral (np.array([]) would be float64).
    selected_indices = torch.tensor(np.random.permutation(np.asarray(selected_indices, dtype=np.int64)))
    return Subset(dataset, selected_indices)

def _ecmmd_loss_for_batch(model, noisy_images, eta, clean_images, ecmmd_fn, kernel, neighbors):
    """Run *model* on one batch and return the squared ECMMD loss for it.

    All image tensors are flattened to ``(batch, -1)``. The noisy inputs serve as
    the conditioning variable ``X`` passed to *ecmmd_fn*.
    """
    denoised_images = model(noisy_images, eta)
    return ecmmd_fn(
        denoised_images.reshape(len(denoised_images), -1),
        clean_images.reshape(len(clean_images), -1),
        noisy_images.reshape(len(noisy_images), -1),
        kernel=kernel,
        neighbors=neighbors,
    ) ** 2


def _plot_denoising_sample(model, device, test_images, noisy_test_images, test_eta, plot_idx):
    """Show the clean image, the noisy input and the model output for one test sample."""
    with torch.no_grad():
        test_noisy = noisy_test_images[plot_idx].unsqueeze(0).to(device)
        test_eta_batch = test_eta[plot_idx].unsqueeze(0).to(device)
        temp_img = model(test_noisy, test_eta_batch).cpu()

        fig = plt.figure(figsize=(5, 3))
        panels = (
            (test_images[plot_idx], 'Actual'),
            (noisy_test_images[plot_idx], 'Low-res'),
            (temp_img, 'High-res'),
        )
        for position, (image, title) in enumerate(panels, start=1):
            plt.subplot(1, 3, position)
            plt.imshow(image.cpu().squeeze(), cmap='gray')
            plt.title(title)
            plt.axis('off')

        plt.show()
        # On non-interactive backends show() does not release the figure, so close it
        # explicitly to avoid keeping one figure alive per plotting epoch.
        plt.close(fig)


def train_model(model, 
                train_dataloader, 
                validation_dataloader, 
                test_images, 
                noisy_test_images, 
                test_eta, 
                optimizer, 
                ECMMD, 
                gaussian_kernel, 
                NEIGHBORS, 
                NUM_EPOCH, 
                plot_idx=0,
                plot_every=20):
    """Train *model* by minimising the squared ECMMD loss; track per-epoch mean losses.

    Both dataloaders must yield ``(noisy_images, eta, clean_images)`` triples, and the
    model is called as ``model(noisy_images, eta)``.

    Args:
        model: module to train; it is moved to the device chosen by ``get_device()``.
        train_dataloader: iterable of training batches (see above).
        validation_dataloader: iterable of validation batches (see above).
        test_images, noisy_test_images, test_eta: indexable collections used only for
            the periodic visualisation of sample ``plot_idx``.
        optimizer: optimizer over ``model``'s parameters.
        ECMMD: loss function, called as ``ECMMD(Z, Y, X, kernel=..., neighbors=...)``.
        gaussian_kernel: kernel passed through to ``ECMMD``.
        NEIGHBORS: neighbour count passed through to ``ECMMD``.
        NUM_EPOCH: number of epochs.
        plot_idx: index of the test sample to visualise.
        plot_every: print losses and plot every this many epochs (starting at epoch 0);
            0 or None disables both.

    Returns:
        ``(model, train_losses, val_losses)``. The loss lists hold the mean batch loss
        per epoch, or 0 for an epoch whose loader yielded no batches.
    """
    device = get_device()
    model.to(device)

    train_losses = []
    val_losses = []

    for epoch in tqdm(range(NUM_EPOCH)):
        model.train()
        total_train_loss = 0.0
        num_train_batches = 0

        for noisy_train_images, train_eta, train_images in train_dataloader:
            optimizer.zero_grad()
            noisy_train_images = noisy_train_images.to(device)
            train_eta = train_eta.to(device)
            train_images = train_images.to(device)

            loss = _ecmmd_loss_for_batch(
                model, noisy_train_images, train_eta, train_images,
                ECMMD, gaussian_kernel, NEIGHBORS,
            )
            loss.backward()
            optimizer.step()

            total_train_loss += loss.item()
            num_train_batches += 1

        avg_train_loss = total_train_loss / num_train_batches if num_train_batches > 0 else 0
        train_losses.append(avg_train_loss)

        model.eval()
        total_val_loss = 0.0
        num_val_batches = 0
        with torch.inference_mode():
            for noisy_validation_images, validation_eta, validation_images in validation_dataloader:
                noisy_validation_images = noisy_validation_images.to(device)
                validation_eta = validation_eta.to(device)
                validation_images = validation_images.to(device)

                val_loss = _ecmmd_loss_for_batch(
                    model, noisy_validation_images, validation_eta, validation_images,
                    ECMMD, gaussian_kernel, NEIGHBORS,
                )
                total_val_loss += val_loss.item()
                num_val_batches += 1

        avg_val_loss = total_val_loss / num_val_batches if num_val_batches > 0 else 0
        val_losses.append(avg_val_loss)

        # The model is still in eval mode here, from the validation pass above.
        if plot_every and epoch % plot_every == 0:
            print(f'Epoch {epoch}, Training Loss: {avg_train_loss}, Validation Loss: {avg_val_loss}')
            _plot_denoising_sample(model, device, test_images, noisy_test_images, test_eta, plot_idx)

    return model, train_losses, val_losses
